package com.idanchuang.component.redis.util;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.script.DefaultRedisScript;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Rate-limiting utility backed by Redis.
 *
 * <p>Each resource name maps to a Redis counter (key {@code "limit:" + name}) that is created with
 * a fixed 1-second time-to-live. Longer time windows are converted into a per-second budget.
 * To reduce Redis round trips, a request smaller than the per-second "pre-quota" batch fetches the
 * whole batch from Redis and keeps the surplus in a local, per-JVM pool that later requests draw
 * from first. After Redis rejects a request for a resource, further requests for that resource
 * fail immediately for {@link #setBufferTime(long) bufferTime} milliseconds.
 *
 * <p>NOTE(review): the local pre-quota pool never expires, so surplus fetched during one window
 * can be spent in a later window.
 *
 * <p>Shared state is kept in concurrent maps, atomics and volatile fields so the static methods
 * can be called from multiple threads.
 *
 * @author yjy
 * @date 2020/9/3 18:21
 **/
public class RedisLimits {

    private static final Logger log = LoggerFactory.getLogger(RedisLimits.class);

    /** Upper bound accepted for any pre-quota factor. */
    private static final double MAX_FACTOR = 0.1D;

    /** Lifetime of a freshly created Redis counter key, in milliseconds (passed to PEXPIRE). */
    private static final long WINDOW_MILLIS = 1000L;

    /** Default pre-quota factor: the pre-quota defaults to 1% of the total quota (capped at 0.1). */
    private static volatile double defaultFactor = 0.01D;

    /**
     * Back-off after a failed Redis request, in milliseconds. During this period, requests for the
     * same resource fail immediately without contacting Redis.
     */
    private static volatile long bufferTime = 10L;

    /** Redis key prefix. */
    private static final String REDIS_KEY_PREFIX = "limit:";

    /** Per-resource pre-quota factor; resources without an entry use the default factor. */
    private static final Map<String, Double> FACTOR_MAP = new ConcurrentHashMap<>();

    /** Local pre-quota pool per resource. */
    private static final Map<String, AtomicInteger> PRE_REQUIRE = new ConcurrentHashMap<>();

    /** Most recent Redis rejection per resource: resource name -> rejection time (epoch millis). */
    private static final Map<String, Long> LAST_FAILED = new ConcurrentHashMap<>();

    /** Lua script that atomically checks and increments the counter; built once instead of per call. */
    private static final DefaultRedisScript<Boolean> REQUIRE_SCRIPT = buildRequireScript();

    /**
     * Sets the default pre-quota factor.
     * @param defaultFactor factor; values greater than 0.1 are clamped to 0.1
     */
    public static void setDefaultFactor(double defaultFactor) {
        RedisLimits.defaultFactor = Math.min(defaultFactor, MAX_FACTOR);
    }

    /**
     * Sets the pre-quota factor for a specific resource.
     * @param name resource name (must not be null)
     * @param factor factor; values greater than 0.1 are clamped to 0.1
     */
    public static void setFactor(String name, double factor) {
        FACTOR_MAP.put(name, Math.min(factor, MAX_FACTOR));
    }

    /**
     * Sets the back-off period applied after a Redis rejection.
     * @param bufferTime back-off in milliseconds; values below 1 are raised to 1
     */
    public static void setBufferTime(long bufferTime) {
        RedisLimits.bufferTime = Math.max(bufferTime, 1L);
    }

    /**
     * Requests 1 unit of quota using a 1-second window.
     * @see #require(String, int, int, int, TimeUnit)
     */
    public static boolean require(String name, int limit) {
        return require(name, limit, 1, 1, TimeUnit.SECONDS);
    }

    /**
     * Requests {@code require} units of quota using a 1-second window.
     * @see #require(String, int, int, int, TimeUnit)
     */
    public static boolean require(String name, int limit, int require) {
        return require(name, limit, require, 1, TimeUnit.SECONDS);
    }

    /**
     * Requests 1 unit of quota using the given window.
     * @see #require(String, int, int, int, TimeUnit)
     */
    public static boolean require(String name, int limit, int timeWindow, TimeUnit timeUnit) {
        return require(name, limit, 1, timeWindow, timeUnit);
    }

    /**
     * Requests rate-limited access.
     *
     * <p>Windows longer than one second are converted into a per-second limit of
     * {@code limit / seconds}. Windows shorter than one second are treated as one second.
     * If an unexpected error occurs while talking to Redis, the request is let through.
     *
     * @param name rate-limited resource name
     * @param limit quota for the whole window
     * @param require units requested, default: 1
     * @param timeWindow time window value, default: 1
     * @param timeUnit time window unit, default: seconds
     * @return whether the request is allowed
     */
    public static boolean require(String name, int limit, int require, int timeWindow, TimeUnit timeUnit) {
        // Non-positive requests consume nothing and always pass.
        if (require < 1) {
            return true;
        }
        // A single request larger than the whole limit can never be satisfied.
        if (require > limit) {
            return false;
        }
        // Fail fast while inside the back-off period after a recent Redis rejection.
        if (predictFailed(name)) {
            log.debug("require predictFailed, name: {}", name);
            return false;
        }
        // Try the local pre-quota pool first (no Redis round trip).
        if (requireLocal(name, require)) {
            log.debug("require requireLocal success, name: {}, require: {}", name, require);
            return true;
        }
        try {
            // Split windows longer than 1s into per-second budgets (minimum 1s).
            long secondWindow = Math.max(timeUnit.toSeconds(timeWindow), 1L);
            // Per-second limit, at least 1.
            int secondLimit = Math.max((int) (limit / secondWindow), 1);
            // Per-second pre-quota batch = limit * factor, spread over the window's seconds.
            int preRequire = (int) (limit * FACTOR_MAP.getOrDefault(name, defaultFactor));
            int secondPreRequire = (int) (preRequire / secondWindow);
            // If this request is smaller than a batch, fetch the whole batch and keep the surplus locally.
            if (secondPreRequire > 0 && secondPreRequire > require) {
                if (requireRedis(name, secondLimit, secondPreRequire)) {
                    updateLocal(name, secondPreRequire - require);
                    return true;
                }
            }
            // Otherwise (or if the batch was rejected) request exactly what is needed.
            return requireRedis(name, secondLimit, require);
        } catch (Exception e) {
            log.error("require error", e);
            // Degrade open: on any unexpected error (e.g. Redis unreachable) let the request through.
            return true;
        }
    }

    /**
     * Takes quota from the local pre-quota pool.
     * @param name resource name
     * @param require units to take
     * @return whether the pool had enough and the units were taken
     */
    private static boolean requireLocal(String name, int require) {
        AtomicInteger local = getLocal(name);
        while (true) {
            int available = local.get();
            if (available < require) {
                return false;
            }
            // CAS so concurrent callers cannot both pass the check and overdraw the pool.
            if (local.compareAndSet(available, available - require)) {
                return true;
            }
        }
    }

    /**
     * Adjusts the local pre-quota pool.
     * @param name resource name
     * @param updateRequire delta to add (negative to consume)
     * @return the pool counter after the update
     */
    private static AtomicInteger updateLocal(String name, int updateRequire) {
        AtomicInteger local = getLocal(name);
        local.addAndGet(updateRequire);
        return local;
    }

    /**
     * Returns the local pre-quota pool for a resource, creating an empty one if absent.
     * @param name resource name
     * @return the pool counter
     */
    private static AtomicInteger getLocal(String name) {
        // Atomic get-or-create: a racing put could otherwise replace a pool that was just credited.
        return PRE_REQUIRE.computeIfAbsent(name, key -> new AtomicInteger(0));
    }

    /**
     * Requests quota from Redis and records success/failure for the back-off logic.
     * @param name rate-limited resource name
     * @param limit per-second limit
     * @param require units requested
     * @return whether Redis granted the units
     */
    private static boolean requireRedis(String name, int limit, int require) {
        log.debug("requireRedis name: {}, limit: {}, require: {}, timeWindow: {}, timeUnit: {}",
                name, limit, require, 1, TimeUnit.SECONDS);
        if (doRequire(name, limit, require)) {
            // Success clears any back-off for this resource.
            clearLastFailed(name);
            return true;
        }
        // Start the back-off period.
        setLastFailed(name);
        return false;
    }

    /**
     * Runs the counter script against Redis.
     * @param name resource name
     * @param limit per-second limit
     * @param require units requested
     * @return whether the counter could be increased without exceeding {@code limit}
     */
    private static boolean doRequire(String name, int limit, int require) {
        List<String> keys = new ArrayList<>();
        keys.add(getKey(name));
        Boolean granted = RedisUtil.getInstance().execute(REQUIRE_SCRIPT, keys,
                String.valueOf(require), String.valueOf(limit), String.valueOf(WINDOW_MILLIS));
        // A Lua `false` is sent back as a nil reply, which some drivers surface as null; unboxing it
        // would throw and be swallowed by require()'s degrade-open path, disabling the limit.
        return Boolean.TRUE.equals(granted);
    }

    /**
     * Builds the Lua script used by {@link #doRequire(String, int, int)}.
     * ARGV = [apply, limit, ttlMillis]. If the key exists, it is incremented by {@code apply} only
     * when the result stays within {@code limit}. Otherwise the key is created with value
     * {@code apply} and a PEXPIRE of {@code ttlMillis}.
     */
    private static DefaultRedisScript<Boolean> buildRequireScript() {
        String script = "local apply,limit,ttl=ARGV[1],ARGV[2],ARGV[3] ";
        script += "if redis.call('EXISTS', KEYS[1])==1 " +
                    "then local curValue = redis.call('GET', KEYS[1]) " +
                    "local targetValue = curValue + apply " +
                    "if targetValue <= tonumber(limit) " +
                        "then redis.call('INCRBY', KEYS[1], apply) " +
                        "return true " +
                    "else " +
                        "return false " +
                    "end " +
                "else " +
                    "redis.call('SET', KEYS[1], apply) " +
                    "redis.call('PEXPIRE', KEYS[1], ttl) " +
                    "return true " +
                "end";
        DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>();
        redisScript.setResultType(Boolean.class);
        redisScript.setScriptText(script);
        return redisScript;
    }

    /**
     * Predicts that this request would fail because a Redis rejection happened recently.
     * @param name resource name
     * @return whether we are still inside the back-off period
     */
    private static boolean predictFailed(String name) {
        long last = LAST_FAILED.getOrDefault(name, 0L);
        return System.currentTimeMillis() < last + bufferTime;
    }

    /**
     * Records the current time as the latest Redis rejection.
     * @param name resource name
     */
    private static void setLastFailed(String name) {
        LAST_FAILED.put(name, System.currentTimeMillis());
    }

    /**
     * Clears the latest Redis rejection.
     * @param name resource name
     */
    private static void clearLastFailed(String name) {
        LAST_FAILED.remove(name);
    }

    /**
     * @param name resource name
     * @return the Redis key for the resource
     */
    private static String getKey(String name) {
        return REDIS_KEY_PREFIX + name;
    }

}
